import numpy as np
import torch
from torch import nn

from config import args
from utils.tool import get_gradient_norm
from utils.optim import get_optim

import torch.distributed as dist
import torch.nn.parallel
from torch.autograd import Variable

# Binary target values and the matching BCE criterion (presumably left over
# from a GAN-style objective; not referenced by the functions in this section).
real_label, fake_label = 1, 0

criterion = nn.BCELoss()

def train_step(args, data_iter, delta, net, optim_x, optim_y, Weight_matrix, out_edges, in_edges, rank):
    """Run one decentralized adversarial-training step on this rank.

    Inner loop (perturbation update): while ``args.step < args.total_steps``,
    back-propagate

        loss_phi = -(CE(net(data + delta), label) - args.gamma * rho)
        rho      = mean over the batch of ||delta_i||_2 ** 2

    call ``optim_y.step`` and gossip-average ``delta`` with neighbours. The
    loop stops once ``args.outer_step`` (NeAda optimizers) or ``args.n_inner``
    (all others) iterations have run. The budget is checked after each
    iteration, so at least one iteration runs whenever the loop is entered.

    Outer step (model update): back-propagate CE(net(data + delta), label),
    call ``optim_x.step`` and gossip-average the model parameters.

    Args:
        args: run configuration. Reads ``device``, ``optim``, ``gamma``,
            ``total_steps`` and ``outer_step`` / ``n_inner``. Increments
            ``args.step`` once per inner iteration and once for the outer step.
        data_iter: iterator yielding ``(data, label)`` batches.
        delta: perturbation added to ``data``. Its first dimension must equal
            the batch size, because ``delta.view(len(data), -1)`` is taken.
        net: the model.
        optim_x: optimizer stepped on the model loss (presumably over
            ``net.parameters()``).
        optim_y: optimizer stepped on ``loss_phi`` (presumably over ``delta``).
        Weight_matrix, out_edges, in_edges, rank: communication topology,
            forwarded to the optimizers and the gossip helpers.

    Returns:
        ``(results, delta)``: a dict of scalar metrics and the updated
        perturbation.
    """
    torch.cuda.set_device(args.device)

    data, label = next(data_iter)
    # delta has a fixed batch dimension. Skip a batch whose size does not
    # match it (e.g. a final partial batch); only one retry is attempted.
    if len(data) != len(delta):
        data, label = next(data_iter)

    data = data.to(args.device)
    label = label.to(args.device)
    batch = len(data)

    # ---- inner loop: update the perturbation delta -----------------------
    # NOTE: NeAda's gradient-norm stopping rule
    # (||grad_y||^2 <= 1 / (args.outer_step + 1)) is not implemented;
    # only the iteration budget below applies.
    inner_steps = 0
    while args.step < args.total_steps:
        optim_x.zero_grad()
        optim_y.zero_grad()
        rho = torch.mean(torch.norm(delta.view(batch, -1), 2, 1) ** 2)
        loss_zt = nn.functional.cross_entropy(net(data + delta), label)
        loss_phi = -(loss_zt - args.gamma * rho)
        loss_phi.backward()

        optim_y.step(Weight_matrix, out_edges, in_edges, args, rank)
        delta = comm_delta(delta, Weight_matrix, out_edges, in_edges, rank)

        args.step += 1
        inner_steps += 1
        # Re-read every iteration: the optimizers receive ``args`` as well.
        max_inner = args.outer_step if 'neada' in args.optim else args.n_inner
        if inner_steps >= max_inner:
            break

    # ---- outer step: update the model on (data + delta) ------------------
    optim_x.zero_grad()
    loss_adversarial = nn.functional.cross_entropy(net(data + delta), label)
    loss_adversarial.backward()
    optim_x.step(Weight_matrix, out_edges, in_edges, args, rank)
    comm_model_param(net, Weight_matrix, out_edges, in_edges, rank)
    args.step += 1

    with torch.no_grad():
        rho = torch.mean(torch.norm(delta.view(batch, -1), 2, 1) ** 2)
        total_loss = loss_adversarial - args.gamma * rho

    # ---- metrics ----------------------------------------------------------
    results = {}
    results['x_grad_norm'] = get_gradient_norm(net.parameters()).item()
    # NOTE(review): optim_y is not zeroed before loss_adversarial.backward().
    # If delta requires grad, delta.grad therefore holds the last inner-step
    # gradient of loss_phi plus the outer-loss gradient. Confirm this is the
    # intended "y_grad_norm".
    results['y_grad_norm'] = get_gradient_norm([delta]).item()
    results['classification_loss'] = loss_adversarial.item()
    results['total_loss'] = total_loss.item()

    is_tiada = 'tiada' in args.optim
    results['x_total_grad_sum'] = optim_x.total_sum if is_tiada else 0
    results['y_total_grad_sum'] = optim_y.total_sum if is_tiada else 0

    tracks_state_sum = is_tiada or 'adagrad' in args.optim
    results['x_state_sum_sum'] = optim_x.state_sum_sum if tracks_state_sum else 0
    results['y_state_sum_sum'] = optim_y.state_sum_sum if tracks_state_sum else 0

    return results, delta


def flatten_tensors(tensors):
    """Concatenate tensors of any shape into one new 1-D tensor.

    Each tensor is flattened in row-major order and the pieces are joined in
    iteration order. The result never aliases the inputs, because
    ``torch.cat`` always allocates. ``reshape`` is used instead of ``view`` so
    non-contiguous inputs (e.g. transposed tensors) are accepted. ``tensors``
    may be any iterable, including a generator such as ``model.parameters()``.
    ``torch.cat`` raises on an empty input.

    Args:
        tensors: iterable of tensors (dtypes are promoted by ``torch.cat``).

    Returns:
        A 1-D tensor with ``sum(t.numel() for t in tensors)`` elements.
    """
    return torch.cat([t.reshape(-1) for t in tensors], dim=0)


def unflatten_tensors(flat, tensors):
    """Split a 1-D ``flat`` tensor back into pieces shaped like ``tensors``.

    Consecutive ``numel()``-sized slices of ``flat`` are taken in the order in
    which ``tensors`` is iterated, and each is reshaped with ``view_as``. The
    results are views sharing storage with ``flat``. Trailing elements of
    ``flat`` beyond the total size of ``tensors`` are ignored. ``tensors`` may
    be any iterable (e.g. ``model.parameters()``) and is consumed once.

    Returns:
        A tuple containing one view per entry of ``tensors``.
    """
    def _views():
        cursor = 0
        for ref in tensors:
            size = ref.numel()
            yield flat.narrow(0, cursor, size).view_as(ref)
            cursor += size

    return tuple(_views())


def comm_delta(delta, Weight_matrix, out_edges, in_edges, rank):
    """Gossip-average ``delta`` with neighbouring ranks, updating it in place.

    For every outgoing edge this rank broadcasts
    ``Weight_matrix[out_edge.dest, rank] * delta`` on that edge's process
    group (non-blocking). It then receives one already-weighted tensor per
    incoming edge, sums them, and sets

        delta <- Weight_matrix[rank, rank] * delta + sum(received)

    Outgoing sends are waited on before ``delta`` is modified, and a
    ``dist.barrier()`` is issued before returning.

    Args:
        delta: tensor to mix. Only its ``.data`` is modified.
        Weight_matrix: mixing weights indexed ``[receiver, sender]``.
        out_edges: edges with ``src == rank``, each with a ``dest`` and a
            ``process_group`` shared with that neighbour.
        in_edges: edges with a neighbour ``src`` and a ``process_group``.
        rank: this process's rank.

    Returns:
        The same ``delta`` object.
    """
    self_weight = Weight_matrix[rank, rank]

    # Non-blocking sends. Keep each work handle (and its payload) so every
    # send is completed explicitly rather than left dangling.
    pending = []
    for out_edge in out_edges:
        assert rank == out_edge.src
        weight = Weight_matrix[out_edge.dest, rank]
        payload = delta.detach().mul(weight.type(delta.dtype))
        work = dist.broadcast(tensor=payload, src=out_edge.src,
                              group=out_edge.process_group, async_op=True)
        pending.append((work, payload))

    # Blocking receives: accumulate the neighbours' pre-weighted tensors.
    received = torch.zeros_like(delta.detach())
    placeholder = torch.zeros_like(received)
    for in_edge in in_edges:
        dist.broadcast(tensor=placeholder, src=in_edge.src, group=in_edge.process_group)
        received.add_(placeholder)

    # Wait only after our receives are posted. A neighbour's receive is what
    # completes our send, and that neighbour may be blocked receiving from us.
    for work, _payload in pending:
        if work is not None:
            work.wait()

    # Fuse: own weighted value plus the neighbours' contributions.
    delta.data.mul_(self_weight.type(delta.data.dtype))
    delta.data.add_(received)

    dist.barrier()

    return delta


def comm_model_param(model, Weight_matrix, out_edges, in_edges, rank):
    """Gossip-average ``model``'s parameters with neighbouring ranks, in place.

    All parameters are packed into one flat vector ``x``. For every outgoing
    edge this rank broadcasts ``Weight_matrix[out_edge.dest, rank] * x`` on
    the edge's process group (non-blocking). It then receives one
    already-weighted vector per incoming edge and sums them into ``r``. Each
    parameter is finally updated as

        p <- Weight_matrix[rank, rank] * p + (matching slice of r)

    Outgoing sends are waited on before the parameters are modified, and a
    ``dist.barrier()`` is issued before returning.

    Args:
        model: module whose ``parameters()`` are mixed (``.data`` is modified).
        Weight_matrix: mixing weights indexed ``[receiver, sender]``.
        out_edges: edges with ``src == rank``, each with a ``dest`` and a
            ``process_group`` shared with that neighbour.
        in_edges: edges with a neighbour ``src`` and a ``process_group``.
        rank: this process's rank.
    """
    params = list(model.parameters())

    # Pack all parameters with a single torch.cat (linear in total size).
    # Concatenating incrementally would re-copy the growing buffer for every
    # parameter. A parameter-free model still takes part in the exchange with
    # an empty CUDA float buffer.
    if params:
        flat = torch.cat([p.detach().reshape(-1) for p in params])
    else:
        flat = torch.tensor([]).cuda()

    # Non-blocking sends. Keep each work handle (and its payload) so every
    # send is completed explicitly rather than left dangling.
    pending = []
    for out_edge in out_edges:
        assert rank == out_edge.src
        weight = Weight_matrix[out_edge.dest, rank]
        payload = flat.mul(weight.type(flat.dtype))
        work = dist.broadcast(tensor=payload, src=out_edge.src,
                              group=out_edge.process_group, async_op=True)
        pending.append((work, payload))

    # Blocking receives: accumulate the neighbours' pre-weighted vectors.
    in_param = torch.zeros_like(flat)
    placeholder = torch.zeros_like(flat)
    for in_edge in in_edges:
        dist.broadcast(tensor=placeholder, src=in_edge.src, group=in_edge.process_group)
        in_param.add_(placeholder)

    # Wait only after our receives are posted. A neighbour's receive is what
    # completes our send, and that neighbour may be blocked receiving from us.
    for work, _payload in pending:
        if work is not None:
            work.wait()

    # Fuse: own weighted parameters plus the neighbours' contributions.
    self_weight = Weight_matrix[rank, rank]
    for p, r in zip(params, unflatten_tensors(in_param, params)):
        p.data.mul_(self_weight.type(p.data.dtype))
        p.data.add_(r)

    dist.barrier()
    